{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Custom Layers\n",
    "\n",
    "One factor behind deep learning's success\n",
    "is the availability of a wide range of layers\n",
    "that can be composed in creative ways\n",
    "to design architectures suitable\n",
    "for a wide variety of tasks.\n",
    "For instance, researchers have invented layers\n",
    "specifically for handling images, text,\n",
    "looping over sequential data,\n",
    "performing dynamic programming, etc.\n",
    "Sooner or later you will encounter (or invent)\n",
    "a layer that does not exist yet in DJL.\n",
    "In these cases, you must build a custom layer.\n",
    "In this section, we show you how.\n",
    "\n",
    "## Layers without Parameters\n",
    "\n",
    "To start, we construct a custom layer (a Block) \n",
    "that does not have any parameters of its own. \n",
    "This should look familiar if you recall our \n",
    "introduction to DJL's `Block` in :numref:`sec_model_construction`. \n",
    "The following `CenteredLayer` class simply\n",
    "subtracts the mean from its input. \n",
    "To build it, we simply need to inherit \n",
    "from the `AbstractBlock` class and implement the `forward()` and `getOutputShapes()` methods."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CenteredLayer extends AbstractBlock {\n",
    "\n",
    "    @Override\n",
    "    protected NDList forwardInternal(\n",
    "            ParameterStore parameterStore,\n",
    "            NDList inputs,\n",
    "            boolean training,\n",
    "            PairList<String, Object> params) {\n",
    "        NDList current = inputs;\n",
    "        // Subtract the mean from the input\n",
    "        return new NDList(current.head().sub(current.head().mean()));\n",
    "    }\n",
    "    \n",
    "    @Override\n",
    "    public Shape[] getOutputShapes(Shape[] inputs) {\n",
    "        // Output shape should be the same as input\n",
    "        return inputs;\n",
    "    } \n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us verify that our layer works as intended by feeding some data through it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDManager manager = NDManager.newBaseManager();\n",
    "\n",
    "CenteredLayer layer = new CenteredLayer();\n",
    "\n",
    "Model model = Model.newInstance(\"centered-layer\");\n",
    "model.setBlock(layer);\n",
    "\n",
    "Predictor<NDList, NDList> predictor = model.newPredictor(new NoopTranslator());\n",
    "NDArray input = manager.create(new float[]{1f, 2f, 3f, 4f, 5f});\n",
    "predictor.predict(new NDList(input)).singletonOrThrow();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now incorporate our layer as a component\n",
    "in constructing more complex models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SequentialBlock net = new SequentialBlock();\n",
    "net.add(Linear.builder().setUnits(128).build());\n",
    "net.add(new CenteredLayer());\n",
    "net.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);\n",
    "net.initialize(manager, DataType.FLOAT32, input.getShape());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an extra sanity check, we can send random data \n",
    "through the network and check that the mean is in fact 0.\n",
    "Because we are dealing with floating point numbers, \n",
    "we may still see a *very* small nonzero number\n",
    "due to quantization."
   ]
  },
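  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why the mean should vanish, a short derivation may help.\n",
    "The symbols $n$ and $\\bar{x}$ are notation introduced only for this sketch:\n",
    "for an input with $n$ elements $x_1, \\ldots, x_n$ and mean $\\bar{x} = \\frac{1}{n} \\sum_{i=1}^{n} x_i$,\n",
    "the mean of the centered output is\n",
    "\n",
    "$$\\frac{1}{n} \\sum_{i=1}^{n} (x_i - \\bar{x}) = \\frac{1}{n} \\sum_{i=1}^{n} x_i - \\bar{x} = \\bar{x} - \\bar{x} = 0.$$\n",
    "\n",
    "In exact arithmetic the result is exactly 0; with 32-bit floats each subtraction and summation step rounds, so a tiny residual is to be expected."
   ]
  },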
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "4"
    }
   },
   "outputs": [],
   "source": [
    "NDArray input = manager.randomUniform(-0.07f, 0.07f, new Shape(4, 8));\n",
    "NDArray y = predictor.predict(new NDList(input)).singletonOrThrow();\n",
    "y.mean();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Layers with Parameters\n",
    "\n",
    "Now that we know how to define simple layers,\n",
    "let us move on to defining layers with parameters\n",
    "that can be adjusted through training. \n",
    "This lets us tell DJL what we need to calculate gradients for.\n",
    "To automate some of the routine work,\n",
    "the `Parameter` class and the `ParameterList` \n",
    "provide some basic housekeeping functionality.\n",
    "In particular, they govern access, initialization, \n",
    "sharing, saving, and loading model parameters. \n",
    "This way, among other benefits, we will not need to write\n",
    "custom serialization routines for every custom layer."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now have all the basic ingredients that we need\n",
    "to implement our own version of DJL's `Linear` layer. \n",
    "Recall that this layer requires two parameters:\n",
    "one for weight and one for bias. \n",
    "In this implementation, we bake in the ReLU activation as a default.\n",
    "In the constructor, `inUnits` and `outUnits`\n",
    "denote the number of inputs and outputs, respectively."
   ]
  },
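  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of what this layer computes, we can write the forward pass as a formula.\n",
    "The symbols below are notation chosen for this illustration only:\n",
    "$\\mathbf{X} \\in \\mathbb{R}^{n \\times d_{\\mathrm{in}}}$ is a minibatch of $n$ examples,\n",
    "$\\mathbf{W} \\in \\mathbb{R}^{d_{\\mathrm{in}} \\times d_{\\mathrm{out}}}$ is the weight,\n",
    "and $\\mathbf{b} \\in \\mathbb{R}^{d_{\\mathrm{out}}}$ is the bias,\n",
    "where $d_{\\mathrm{in}}$ and $d_{\\mathrm{out}}$ correspond to `inUnits` and `outUnits`.\n",
    "Assuming the weight is stored with shape (`inUnits`, `outUnits`), as in the implementation that follows, the output is\n",
    "\n",
    "$$\\mathbf{Y} = \\max(0, \\mathbf{X} \\mathbf{W} + \\mathbf{b}), \\qquad \\mathbf{Y} \\in \\mathbb{R}^{n \\times d_{\\mathrm{out}}},$$\n",
    "\n",
    "where the bias is broadcast across the $n$ rows and the maximum (the ReLU) is applied elementwise."
   ]
  },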
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We instantiate a new `Parameter` by calling its constructor and passing in\n",
    "a name, a reference to the block it is to be associated with, and its type which\n",
    "we can set from `ParameterType`.\n",
    "Then we call `addParameter()` in our `Linear`'s constructor \n",
    "with the newly instantiated `Parameter` and its respective `Shape`.\n",
    "We do this for both weight and bias."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "19"
    }
   },
   "outputs": [],
   "source": [
    "class MyLinear extends AbstractBlock {\n",
    "\n",
    "    private Parameter weight;\n",
    "    private Parameter bias;\n",
    "    \n",
    "    private int inUnits;\n",
    "    private int outUnits;\n",
    "\n",
    "    // outUnits: the number of outputs in this layer \n",
    "    // inUnits: the number of inputs in this layer\n",
    "    public MyLinear(int outUnits, int inUnits) {\n",
    "        this.inUnits = inUnits;\n",
    "        this.outUnits = outUnits;\n",
    "        weight = addParameter(\n",
    "            Parameter.builder()\n",
    "                .setName(\"weight\")\n",
    "                .setType(Parameter.Type.WEIGHT)\n",
    "                .optShape(new Shape(inUnits, outUnits))\n",
    "                .build());\n",
    "        bias = addParameter(\n",
    "            Parameter.builder()\n",
    "                .setName(\"bias\")\n",
    "                .setType(Parameter.Type.BIAS)\n",
    "                .optShape(new Shape(outUnits))\n",
    "                .build());\n",
    "    }\n",
    "    \n",
    "    @Override\n",
    "    protected NDList forwardInternal(\n",
    "            ParameterStore parameterStore,\n",
    "            NDList inputs,\n",
    "            boolean training,\n",
    "            PairList<String, Object> params) {\n",
    "        NDArray input = inputs.singletonOrThrow();\n",
    "        Device device = input.getDevice();\n",
    "        // Since we added the parameter, we can now access it from the parameter store\n",
    "        NDArray weightArr = parameterStore.getValue(weight, device, false);\n",
    "        NDArray biasArr = parameterStore.getValue(bias, device, false);\n",
    "        return relu(linear(input, weightArr, biasArr));\n",
    "    }\n",
    "    \n",
    "    // Applies linear transformation\n",
    "    public static NDArray linear(NDArray input, NDArray weight, NDArray bias) {\n",
    "        return input.dot(weight).add(bias);\n",
    "    }\n",
    "    \n",
    "    // Applies relu transformation\n",
    "    public static NDList relu(NDArray input) {\n",
    "        return new NDList(Activation.relu(input));\n",
    "    }\n",
    "    \n",
    "    @Override\n",
    "    public Shape[] getOutputShapes(Shape[] inputs) {\n",
    "        return new Shape[]{new Shape(outUnits, inUnits)};\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we instantiate the `MyLinear` class \n",
    "and access its model parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// 5 units in -> 3 units out\n",
    "MyLinear linear = new MyLinear(3, 5); \n",
    "var params = linear.getParameters();\n",
    "for (Pair<String, Parameter> param : params) {\n",
    "    System.out.println(param.getKey());\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us initialize and test our `Linear`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "20"
    }
   },
   "outputs": [],
   "source": [
    "NDArray input = manager.randomUniform(0, 1, new Shape(2, 5));\n",
    "\n",
    "linear.initialize(manager, DataType.FLOAT32, input.getShape());\n",
    "\n",
    "Model model = Model.newInstance(\"my-linear\");\n",
    "model.setBlock(linear);\n",
    "\n",
    "Predictor<NDList, NDList> predictor = model.newPredictor(new NoopTranslator());\n",
    "predictor.predict(new NDList(input)).singletonOrThrow();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also construct models using custom layers.\n",
    "Once we have that we can use it just like the built-in dense layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDArray input = manager.randomUniform(0, 1, new Shape(2, 64));\n",
    "\n",
    "SequentialBlock net = new SequentialBlock();\n",
    "net.add(new MyLinear(8, 64)); // 64 units in -> 8 units out\n",
    "net.add(new MyLinear(1, 8)); // 8 units in -> 1 unit out\n",
    "net.initialize(manager, DataType.FLOAT32, input.getShape());\n",
    "\n",
    "Model model = Model.newInstance(\"lin-reg-custom\");\n",
    "model.setBlock(net);\n",
    "\n",
    "Predictor<NDList, NDList> predictor = model.newPredictor(new NoopTranslator());\n",
    "predictor.predict(new NDList(input)).singletonOrThrow();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "* We can design custom layers via the Block class. This allows us to define flexible new layers that behave differently from any existing layers in the library.\n",
    "* Once defined, custom layers can be invoked in arbitrary contexts and architectures.\n",
    "* Blocks can have local parameters, which are stored in a `LinkedHashMap<String, Parameter>` object in each `parameters` attribute.\n",
    "\n",
    "\n",
    "## Exercises\n",
    "\n",
    "1. Design a layer that learns an affine transform of the data.\n",
    "1. Design a layer that takes an input and computes a tensor reduction, \n",
    "   i.e., it returns $y_k = \\sum_{i, j} W_{ijk} x_i x_j$.\n",
    "1. Design a layer that returns the leading half of the Fourier coefficients of the data. Hint: look up `Fast Fourier Transform`."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
